Autism spectrum disorders detection based on multi-task transformer neural network

Autism Spectrum Disorders (ASD) are neurodevelopmental disorders that cause people difficulties in social interaction and communication. Identifying ASD patients based on resting-state functional magnetic resonance imaging (rs-fMRI) data is a promising diagnostic tool, but challenging due to the complex and unclear etiology of autism. And it is difficult to effectively identify ASD patients with a single data source (single task). Therefore, to address this challenge, we propose a novel multi-task learning framework for ASD identification based on rs-fMRI data, which can leverage useful information from multiple related tasks to improve the generalization performance of the model. Meanwhile, we adopt an attention mechanism to extract ASD-related features from each rs-fMRI dataset, which can enhance the feature representation and interpretability of the model. The results show that our method outperforms state-of-the-art methods in terms of accuracy, sensitivity and specificity. This work provides a new perspective and solution for ASD identification based on rs-fMRI data using multi-task learning. It also demonstrates the potential and value of machine learning for advancing neuroscience research and clinical practice.


Introduction
ASD (Autism Spectrum Disorders) is a heterogeneous condition that affects communication, behavior, and social interactions in various ways and degrees [1].According to the latest Diagnostic and Statistical Manual of Mental Disorders (DSM-5), ASD encompasses a spectrum of disorders that were previously diagnosed separately, such as autism, Asperger's syndrome, and other pervasive developmental disorders.The global prevalence of ASD has increased dramatically over the years, reaching 1 in 59 children in the United States in 2014 [2].ASD poses a major public health challenge, as it impacts not only the individuals with ASD, but also their families and society [3].Early diagnosis and intervention are crucial for improving the outcomes and reducing the costs of ASD [4], but the current standard diagnosis relies on subjective and time-consuming assessments by multidisciplinary teams using standardized tools [5].These assessments require highly specialized knowledge and experience from the evaluators, and are often inaccessible or unavailable to many patients [6].Therefore, there is an urgent need for objective and efficient diagnostic methods based on biological markers.
With the rapid development of AI (Artificial Intelligence) technology, machine learning as a subfield of AI, it has largely enhanced the role of computational methods in neuroscience [7].Machine learning has been successfully applied in Alzheimer's disease, mild cognitive impairment [8,9], temporal lobe epilepsy, schizophrenia, Parkinson's [10], dementia [11,12], ADHD [13,14], ASD [15,16] and major depressive disorder [17].In particular, the identification of ASD has made great progress and a series of effective methods have been developed [18].These methods can be briefly divided into two categories as follows: (1) Based on traditional machine learning methods, it models ASD data as a binary classification problem using traditional machine learning techniques.Crippa et al. [19] used support vector machine (SVM) algorithm to segment ASD patient samples and normal controls (NC) samples by fitting a hyperplane.Rane et al. [20] used logistic regression method to predict ASD diagnosis by transforming fMRI data into probabilities of specific binary values through linear operations.Abbas et al. [21] used an integrated learning approach to construct an ASD screening tool by combining a parent questionnaire-based classifier and a behavioral video-based classifier; (2) Based on the deep learning approach, it uses deep neural networks to extract hidden features in ASD data for ASD identification.Heinsfeld et al. [22] used autoencoders to downscale rs-fMRI data and then used deep neural networks for ASD prediction.Alsaade et al. [23] performed prediction of ASD disease by constructing a functional brain connectivity matrix and projecting it to a deep feature space.Pavăl [24] used convolutional neural networks for facial abnormality identification in ASD patients.
Despite the success of these methods, the identification of ASD remains a challenge due to the complex causes of autism formation and unclear pathogenesis [25,26].Moreover, most existing methods are based on single-task learning, which ignores the potential correlations and complementarities among different ASD recognition tasks [27,28].To address these issues, we propose a novel multi-task learning framework for ASD identification based on resting-state functional magnetic resonance imaging (rs-fMRI) data.Figure 1 is the multitasking transformer framework diagram.Rs-fMRI is increasingly used to study neural connectivity and identify biomarkers of psychiatric disorders.It performs imaging based on blood oxygen level-dependent (BOLD) signal changes in brain regions in a non-invasive manner.Thus rs-fMRI-based ASD identification can provide more accurate, stable and interpretable predictions.
The main contributions and novelties of our work are as follows: This paper proposes a novel multi-task learning framework for ASD identification based on rs-fMRI data, which can leverage useful information from multiple related tasks to improve the generalization performance of the model.We introduce a temporal encoding module to encode the rs-fMRI data, which can capture the sequential information embedded in the temporal nodes.Meanwhile, we adopt an attention mechanism to extract ASD-related features from each rs-fMRI dataset, which can enhance the feature representation and interpretability of the model.
We design a feature sharing module to share the ASD features learned from each dataset, which can exploit the correlations and complementarities among different tasks.
We conduct extensive experiments on two public rs-fMRI datasets to evaluate the effectiveness of our proposed method.The results show that our method outperforms state-of-the-art methods in terms of accuracy, sensitivity and specificity.This work provides a

Materials
In the present study, we used rs-fMRI data from the Autism Imaging Data Exchange (ABIDE).Due to the limited number of subjects at the site, we selected 2 different sites (number of subjects > 100) from a large number of sites, including UM and NYU.Also, data and detailed information are available at https://fcon_1000.projects.nitrc.org/indi/abide/,where Table 1 shows the demographic information of the subjects aggregated.

Data preprocessing
There is no consensus on the best methods for preprocessing resting state fMRI data.Rather than being prescriptive and favoring a single processing strategy, we have preprocessed the data using Connectome Computation System (CCS), Configurable Pipeline for the Analysis of Connectomes (CPAC), Data Processing Assistant for Resting-State fMRI (DPARSF), Neuroimaging Analysis Kit (NIAK), each of which was implemented using the chosen parameters and settings of the pipeline developers.
The preprocessing steps implemented by the different pipelines are very similar.The largest changes are for the specific algorithms used for each step, their software implementations, and the parameters used.The following sections outline the different preprocessing steps and their differences in the pipeline.

Basic processing
Step For strategies that include global signal correction, the global mean signal was included with nuisance variable regression.Band-pass filtering (0.01-0.1 Hz) was applied after nuisance variable regression.

Registration
A transform from original to template (MNI152) space was calculated for each dataset from a combination of functional-to-anatomical and anatomical-to-template transforms.The anatomical-to-template transforms were calculated using a two step procedure that involves (one or more) linear transform that is later refined with a very high dimensional non-linear transform.When data are written into template space (typically after the calculation of derivatives, except for NIAK) all transforms are used simultaneously to avoid multiple interpolations.

Methods
In this section, we design the multitask Transformer framework to improve the ASD prediction performance by sharing the knowledge learned from multiple tasks.Specifically, Section"Problem definition" formally defines the problem.Section"Location coding" describes how to encode positions according to the order of time nodes.Section"Attention module" defines the way the attention mechanism in the Transformer captures useful features.
Section"Feature sharing" describes the process of feature sharing among different tasks.Section"Objective function" defines the objective function for Optimization of the objective function.

Problem definition
In this section, we describe the proposed multi-task Transformer learning framework.Suppose we have D tasks and the rs-fmri dataset as follows (1), An instance as follows (2) in X d contains T time nodes and N brain regions, and the corresponding label y d ∈ {0, 1} is a binary classification task, In the experiment, label 1 represents illness, label 0 represents no disease, and the label and label of both tasks have the same significance.We further assume that there are D different Transformer networks, and each Transformer network consists of L-layer feedforward networks, where the lth layer network extracts the features of task d through f l d ∈ R T×N .Specifically, our goal is to improve the generalization performance of task d by sharing features learned from other tasks as follows (3). (1)

Location coding
Temporal order information in time series data helps to improve model prediction accuracy [29].To take full advantage of the sequential information embedded in the time nodes in the rs-fMRI data, we inject information about the position in the time node sequence for each input data.Specifically, we obtain a position-encoded PE with the same dimensionality as x d using the sine and cosine function, which is calculated as follows (4,5).
where t denotes the position of the time node in T time nodes, 2n denotes the brain region of even number, and 2n + 1 denotes the brain region of base number.In our model, x d represents the input features of the task d.It is a two-dimensional matrix, where each row corresponds to a time node and each column corresponds to a brain region.Thus, the dimension of x d is (T, R), where T is the number of time nodes and R is the number of brain regions.Then, the location information embedding, which we implement by summing the location encoding PE and x d , is calculated as follows (6).

Attention module
The Transformer network consists of several attention modules to improve the ASD prediction performance, as shown in Fig. 2. In our model, Q, K, and V represent the query (Query), the Key (Key), and the value (Value), respectively.These manipulations are central parts of the attention mechanism and are used to compute correlations between input features.f l d represents the l-layer feature of the task d.These features are extracted through feed-forward networks and attention modules and can capture important information about the input data.Our goal is to improve the generalization performance of task d by sharing features learned from other tasks.
Specifically, the features of f l d are first extracted by three different linear operations of Q, K and V, and the number of channels is reduced to half of the original one to reduce the computational effort, as follows (7-9). (2) denote the output feature vectors.Then, we matrix multiply W Q and W K to calculate the correlation weights between time nodes and score them using softmax operation.Finally, we weight sum the correlation weights and W V to obtain the attention feature vector, which is calcu- lated as follows (10).
where f l d ∈ R T×N denotes the attention feature vector.Finally, we fuse the attentional feature vector f l d and the feature f l d , which aims to compensate for the information lost when the attentional mechanism captures the features, calculated as follows (11).
where f l d ∈ R T×N denotes the output fused features.In addition, each layer of the feedforward network consists of an attention module and a Forward network, and the fused features are transformed into the task-specific feature space by a fully connected Forward network, calculated as follows (12).

Feature sharing
To realize the interaction of features between tasks, we build a feature sharing module, as shown in Fig.  where f l+1 d denotes the output of the l + 1-layer network.We can identify specific layer tasks by setting M i,k , i < D, k < D to zero, or share more features by assigning them higher values.

Objective function
Feedforward networks do not change the dimensionality of the feature vectors; however, high-dimensional and high-noise data have a negative impact on the prediction performance.To solve this problem, we reduce the dimensionality of the feature vector f L d ∈ R T×N by FC Layers and perform the prediction.FC Layers consists of three layers of fully connected operations, the first two layers are used to reduce the dimensionality and the last layer gets the prediction output, which is calculated as follows ( 14): where y d denotes the output, and W i=1,2,3 and b i=,1,2,3 denote the corresponding parameters.Then, we use the binary cross-entropy as the loss and the objective function is calculated as follows (15).

Results and discussion
In this section, we conduct extensive experiments to verify the effectiveness of our approach.Specifically, Section"Experimental setup" describes the experimental setting and setup.Section"Evaluation metrics" gives the evaluation metrics to evaluate the experimental results.Section"Experimental results and discussion" presents the comparison of our method with the current popular methods on two ASD datasets and the analysis of the experimental results.

Experimental setup
The experiments are programmed and implemented as follows: PyTorch 1.9, Python 3.8, using a GeForce RTX 3090 GPU for training.With grid search method for tuning hyperparameters, we use Adam as the training optimizer with 120 iterations, an initial learning rate of 1 × 10 −5 , 50% decay every 30 iterations, and a Batch size of 16.The number of feedforward network layers L is 5, ( 14)

Evaluation metrics
We used Accuracy, Sensitivity and Specificity as metrics to evaluate the ASD identification results.All methods are tested using these metrics and calculated as follows (16)(17)(18): where True Positive indicates the number of ASD-positive patients correctly classified, True Negative indicates the number of ASD-false-negative patients, False Positive indicates the number of ASD-false-positive patients, and False Negative indicates the number of ASD-negative patients correctly classified.

Effects of loss function
This experiment established two datasets.Figure 4   From the figure, we can observe that (1) the TN value is the largest among the four values, i.e., the number of correctly predicted NC samples is the largest.Meanwhile, FP is the smallest, i.e., the number of incorrectly predicted NC samples is the least.This again validates that our method has a low misdiagnosis rate; (2) TP indicates the number of samples that correctly identified ASD patients.The difference between the TP and FN values is not significant.The reason for this result is that the number of NC samples in the training sample is high, which leads to category imbalance and thus affects the ability of the model to identify ASD patients; and (3) In the confusion matrix corresponding to the two datasets, the proportions of TN, FP, FN and TP are similar, which proves that the model has some generalizability.In summary, our method can identify NC patients well and has some ability to identify ASD patients.

Ablation studies
As shown in Fig. 1, Multitasking transformer framework diagram can be regarded as a federated network composed of multiple features share modules.In this section, we conduct ablation studies to verify the effectiveness of the crucial components in multi-task learning framework and evaluate the impact of each single task network on the results.The transformer network consists of several attention modules to improve the ASD prediction performance.Based on the transformer network, we built a single network and a feature sharing module respectively.All experiments are performed with the same hyperparameter configuration.Table 2 shows the ablation studies with different network configurations.
From index in Table 2, we can see that when we simply add a single task network to the transformer network, Accuracy and Specificity all suffer a decline, but the sensitivity suffers a rise.This shows that adding the multitask will bring better results.When we applied Single task network and Multitask network to the NYU dataset, the accuracy, sensitivity, and specificity indicators of Single task network were 63.15%, 52.63%, and 73.68%, respectively.The accuracy, sensitivity, and specificity indicators of Multitask network were 67.64%, 40.00%, and 89.47%, respectively.When we applied Single task network and Multitask network to the UM dataset, the accuracy, sensitivity, and specificity indicators of Single task network were 70.68%, 66.00%, and 73.68%, respectively, while the accuracy, sensitivity, and specificity indicators of Multitask network were 72.00%, 55.00%, and 86.66%, respectively.In summary, adding the feature sharing module to the transformer network has the best recognition and prediction performance for rs-fMRI data, indicating the necessity of the feature sharing module in deep learning networks.

Comparison with the state-of-the-art methods
In this section, we compare the proposed method with some popular machine learning and deep learning methods, including support vector machines [30], random forests [31], multilayer perceptron [32], SAENet [33], MLwSGSU [34] and MCNNet [35].To test the results of these methods, we used their public codes on the NYU and UM datasets for training and evaluation.The experimental results of the seven models on the two datasets are shown in Figs. 6 and 7.The figures show that compared with other methods, the accuracy values obtained by us have better results.Figures 6 and 7 show that compared with the suboptimal method, the accuracy values obtained by us have increased by 4.54% and 5.88% on the two data sets respectively.
Table 3 shows the results of the different methods on the NYU and UM datasets.On the NYU dataset, our proposed model achieves 67.64%, 40% and 89.47% in accuracy, sensitivity and specificity respectively, which are the best results among all compared methods.Unlike the case on the NYU dataset, on the UM dataset, our proposed model achieves 72%, 50% and 86.66% in accuracy, sensitivity and specificity respectively.Through qualitative comparisons on the two datasets, we find that both our model can guarantee the improvement of comprehensive performance and maintain a high specificity without introducing too many false positives.Therefore, compared to other methods, we believe that Multitasking Transformer framework can better cope with the ASD prediction.
To sum up, (1) Multi-task learning methods are competitive with traditional machine learning methods and deep learning methods in ASD recognition; (2) Our methods are significantly better than other methods in both accuracy and specificity; and (3) Our methods are not as sensitive as other methods.We hypothesize that since the number of NC patients in the dataset is slightly more than that of ASD patients, the attention mechanism when training the model is more biased towards learning to capture NC features, thus negatively influencing the extraction of ASD features, and therefore less sensitive.In addition, the method is effective for ASD identification.In conclusion, our method can better identify ASD patients with a lower probability of misdiagnosis of NC patients.

Limitations
Although our method performs very well compared to other methods, several limitations exist.Although our method has higher accuracy and specificity, there is still lower sensitivity.And the accuracies are only 67.64% and 72%.This is attributed to the amount of training data being too small, which leads too poor generalization on the rs-fMRI data.We plan to explore more effective data augmentation techniques in future work.

Fig. 1
Fig. 1 Multitasking transformer framework diagram where f l d ∈ R T×N denotes the output, Relu(•) denotes the activation function, and W f and b f denote the corre- sponding parameters.
3. Each layer of the network defines D learnable activation mappings M D = {M d } D d=1 , where M d = {M 1d , ..., M Dd } .We use M D to linearly combine the feature vectors of different task networks and use them as inputs for the next layer of feedforward networks.Specifically, we matrix the activation mapping M D = M 11 ... M D1... M dd ... M 1D ... M DD and use it to linearly combine multiple feature vectors, which are computed as follows(13).

Fig. 2 Fig. 3
Fig. 2 Attention module d y d ∈Y d −y d log y d and the three fully-connected layers in FC Layers have output dimensions of 4096, 2048, and 2. In addition, we divide the ASD data in Section"Materials and methods" randomly into a training set and a test set in the ratio of 8:2 ratio randomly into training set and test set for subsequent experiments.
shows the loss plot lines during the training of the experiments.For the Fig. 4a and b loss function, the loss value curve has fluctuated several times in a large range during the training process, which may indicate the occurrence of gradient explosion, resulting in excessive weight update of the model, thus causing instability of the model.Therefore, we need multiple training to improve the stability of training.From the graphs, the following conclusions can be drawn: (1) The experiments have converged for both datasets and the experimental results are reliable; (2) The experiments both have the fastest rate of decline until 90 iterations.This indicates that the model was able to effectively learn how to classify ASD patients and NC patients during this time; and (3) The experiments both reached convergence at 105-120 iterations, and the model was able to fit the training data.In summary, our model can fit the ASD dataset well and the experimental results are reliable and valid.

Figure 5
Figure5is a confusion matrix showing the number of true negative (TN), false positive (FP), false negative (FN) and true positive (TP) samples.From the figure, we can observe that (1) the TN value is the largest among the four values, i.e., the number of correctly predicted NC samples is the largest.Meanwhile, FP is the smallest, i.e., the number of incorrectly predicted NC samples is the least.This again validates that our method

Fig. 4 Fig. 5
Fig. 4 Training loss plots.a NYU corresponding loss plot.b NYU corresponding loss plot

Conclusion
In this study, we propose the multi-task Transformer network, which are essential for predicting and diagnosing ASD diseases.The proposed network utilizes multi-task learning and attention mechanisms for ASD recognition and achieves excellent classification performance on NYU and UM ASD datasets.In addition, the attention mechanism enhances the model's attention to ASDrelated features.Multi-task learning enhances the model generalization performance by fusing knowledge learned from different ASD datasets.We evaluated our method

Fig. 6 Fig. 7
Fig. 6 Accuracy comparison of different network models in NYU dataset Each pipeline implemented some form of nuisance variable regression to clean confounding variation due to physiological processes (heart beat and respiration), head motion, and low frequency scanner drifts, from the fMRI signal.

Table 1
Demographic information of subjects

Table 2
Ablation studies with different network configurations

Table 3
Experimental results of different methods